19-10-8-1

n500,2k10n\le 500, 2\le k\le 10

原题有加强

Solution

首先考虑如何判定一个数存在于aa

每次加上的数[1,k)\in [1, k),所以除个位外,其他每一位每次最多+1+1. 也就是说,主要就是要判断满足某条件下,个位是否合法

g[i][p][x][a]g[i][p][x][a]表示,从第i+1i+1位开始往上都填完了,前面的数位最大值为pp,个位为aa,要在第ii位填xx的话,aa会变成多少

直接处理gg并不好处理,考虑再记f[i][p][a]f[i][p][a]表示第i+1i+1位开始往上都填完了,前面的数位最大值为pp,个位为aa,要在i+1i+1位产生11的进位,aa会变成多少

gg是处理一次进位,ff处理某一位填某个数

这样就能判定一个数是否合法了


接着考虑进行计数dp,设dp[i][j][p][x]dp[i][j][p][x]表示当前到点ii,判定到第jj位,最大值为pp,个位为xx的方案数

枚举下一个点kk,有转移:dp[k][j1][max{p,d[k]}][g[j1][p][d[k][x]]dp[i][j][p][x]dp[k][j - 1][\max \{p, d[k]\}][g[j - 1][p][d[k][x]] \leftarrow dp[i][j][p][x]

直接转移是O(n3k2)O(n^3k^2)

将刷表转化为填表,发现上一个点是dfs序连续的一段区间(在ii的父亲后,在ii前面),于是可以前缀和优化到O(n2k2)O(n^2k^2)


需要注意,在进行转移的时候,状态不一定合法。所以需要预处理出合法的转移状态,具体见代码

Code

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
#include <bits/stdc++.h>

#define x first
#define y second
#define y1 Y1
#define y2 Y2
#define mp make_pair
#define pb push_back
#define DEBUG(x) cout << #x << " = " << x << endl;

using namespace std;

typedef long long LL;
typedef pair <int, int> pii;

template <typename T> inline int Chkmax (T &a, T b) { return a < b ? a = b, 1 : 0; }
template <typename T> inline int Chkmin (T &a, T b) { return a > b ? a = b, 1 : 0; }
template <typename T> inline T read ()
{
T sum = 0, fl = 1; char ch = getchar();
for (; !isdigit(ch); ch = getchar()) if (ch == '-') fl = -1;
for (; isdigit(ch); ch = getchar()) sum = (sum << 3) + (sum << 1) + ch - '0';
return sum * fl;
}

inline void proc_status ()
{
ifstream t ("/proc/self/status");
cerr << string (istreambuf_iterator <char> (t), istreambuf_iterator <char> ()) << endl;
}

const int MAXN = 500;
const int MAXK = 10;
const int MOD = 998244353;

inline void Add (int &a, int b) { if ((a += b) >= MOD) a -= MOD; }

int N, K, A[MAXN + 5];
vector <int> G[MAXN + 5];
int dfn[MAXN + 5], idfn[MAXN + 5], dfs_clock;
int fa[MAXN + 5];

inline void dfs (int x, int f)
{
fa[x] = f, idfn[dfn[x] = ++dfs_clock] = x;
for (int i = 0; i < G[x].size(); ++i)
{
int y = G[x][i];
if (y == f) continue;
dfs (y, x);
}
}

int f[MAXN + 5][MAXK + 5][MAXK + 5];
int g[MAXN + 5][MAXK + 5][MAXK + 5][MAXK + 5];
int can[MAXK + 5][MAXK + 5][MAXK + 5]; // 需要判断个位是否合法

inline void Init ()
{
for (int i = 1; i <= N; ++i) sort (G[i].begin(), G[i].end());
dfs (1, 0);

for (int p = 0; p < K; ++p)
for (int q = 0; q < K; ++q) if (p > 0 || q > 0)
{
int now = q;
while (now < K) can[p][q][now] = 1, now += max (now, p);
f[0][p][q] = now % K;
}

for (int i = 1; i <= N; ++i)
for (int p = 0; p < K; ++p)
for (int q = 0; q < K; ++q) if (p > 0 || q > 0)
{
int now = q;
for (int t = 0; t < K; ++t) now = f[i - 1][max (p, t)][now];
f[i][p][q] = now;
}

for (int i = N; i >= 1; --i)
for (int p = 0; p < K; ++p)
for (int q = 0; q < K; ++q) if (p > 0 || q > 0)
{
int now = q;
g[i][p][0][q] = now;
for (int j = 1; j < K; ++j)
{
now = f[i - 1][max (j - 1, p)][now];
g[i][p][j][q] = now;
}
}
}

int Dp[MAXN + 5][MAXN + 5][MAXK + 5][MAXK + 5];
int prefix[MAXN + 5][MAXN + 5][MAXK + 5][MAXK + 5];

inline int init_sum (int i)
{
for (int j = 0; j <= N; ++j)
for (int p = 0; p < K; ++p)
for (int q = 0; q < K; ++q)
if (!i) prefix[i][j][p][q] = Dp[i][j][p][q];
else prefix[i][j][p][q] = (prefix[i - 1][j][p][q] + Dp[i][j][p][q]) % MOD;
}

inline int get_sum (int l, int r, int j, int p, int q)
{
if (!l) return prefix[r][j][p][q];
return (prefix[r][j][p][q] - prefix[l - 1][j][p][q] + MOD) % MOD;
}

inline void Solve ()
{
Init ();

for (int i = 1; i <= N; ++i) Dp[0][i][0][1] = 1;
init_sum (0);

for (int i = 1; i <= N; ++i)
{
for (int j = 1; j <= N; ++j)
for (int p = 0; p < K; ++p)
for (int q = 0; q < K; ++q) if (p > 0 || q > 0)
{
int x = idfn[i];
int l = dfn[fa[x]], r = i - 1;

if (j > 1) Add (Dp[i][j - 1][max (p, A[x])][g[j - 1][p][A[x]][q]], get_sum (l, r, j, p, q));
else if (can[p][q][A[x]]) Add (Dp[i][0][max (p, A[x])][q], get_sum (l, r, j, p, q));
}

init_sum (i);
}

int ans = 0;
for (int p = 0; p < K; ++p)
for (int q = 0; q < K; ++q)
Add (ans, prefix[N][0][p][q]);

cout << ans << endl;
}

inline void Input ()
{
N = read<int>(), K = read<int>();
for (int i = 1; i <= N; ++i) A[i] = read<int>();
for (int i = 1; i < N; ++i)
{
int x = read<int>(), y = read<int>();
G[x].pb (y);
G[y].pb (x);
}
}

int main()
{

freopen("buried.in", "r", stdin);
freopen("buried.out", "w", stdout);

Input ();
Solve ();

return 0;
}